# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hysop.tools.htypes import first_not_None, to_tuple
from hysop.tools.sympy_utils import (
nabla,
partial,
subscript,
subscripts,
exponent,
exponents,
xsymbol,
get_derivative_variables,
)
import sympy as sm
from sympy.printing.str import StrPrinter, StrReprPrinter
from sympy.printing.latex import LatexPrinter
from packaging import version
if version.parse(sm.__version__) > version.parse("1.7"):
from sympy.printing.c import C99CodePrinter
else:
from sympy.printing.ccode import C99CodePrinter
class BasePrinter:
    """Mixin shared by the name printers of this module.

    Provides a common derivative printer and a ``_print`` override that
    reports which printer failed on which expression before re-raising.
    """

    def print_Derivative(self, expr):
        """Build the name variants of a derivative expression.

        The differentiated expression ``expr.args[0]`` is printed with
        ``print_all_names`` and every distinct differentiation variable with
        each of the four name printers. The result is whatever
        ``DifferentialStringFormatter.format_pd`` returns; the printer
        subclasses index it with 0 (name), 1 (pretty name), 2 (var name) and
        3 (LaTeX name).
        """
        (bvar, pvar, vvar, lvar) = print_all_names(expr.args[0])
        all_xvars = get_derivative_variables(expr)
        # Distinct differentiation variables in first-appearance order.
        # A set would order them by hash value, and string hashes are
        # randomized per process, so the generated names could differ
        # between runs (or between processes of the same run).
        xvars = tuple(dict.fromkeys(all_xvars))
        # Differentiation order with respect to each distinct variable.
        varpows = tuple(all_xvars.count(x) for x in xvars)
        bxvars = tuple(print_name(x) for x in xvars)
        pxvars = tuple(print_pretty_name(x) for x in xvars)
        vxvars = tuple(print_var_name(x) for x in xvars)
        lxvars = tuple(print_latex_name(x) for x in xvars)
        return DifferentialStringFormatter.format_pd(
            bvar, pvar, vvar, lvar, bxvars, pxvars, vxvars, lxvars, varpows=varpows
        )

    def _print(self, expr, **kwds):
        """Delegate to the next printer in the MRO, reporting failures.

        If printing raises, a diagnostic naming this printer class and the
        expression is printed, then the original exception is re-raised
        unchanged.
        """
        try:
            return super()._print(expr, **kwds)
        except Exception:
            msg = "FATAL ERROR: {} failed to print expression {}."
            print(msg.format(type(self).__name__, expr))
            raise


class NamePrinter(BasePrinter, StrReprPrinter):
    """Printer producing the plain ``name`` of an expression.

    Objects carrying a ``name`` (or ``_name``) attribute print as that
    attribute; everything else falls back to sympy's ``StrReprPrinter``,
    with spaces removed from sums and products.
    """

    def _print(self, expr, **kwds):
        if hasattr(expr, "name"):
            return expr.name
        elif hasattr(expr, "_name"):
            return expr._name
        return super()._print(expr, **kwds)

    def _print_Derivative(self, expr):
        # Entry 0 of the name variants built by BasePrinter.print_Derivative.
        return super().print_Derivative(expr)[0]

    def _print_Add(self, expr):
        return super()._print_Add(expr).replace(" ", "")

    def _print_Mul(self, expr):
        return super()._print_Mul(expr).replace(" ", "")

    def emptyPrinter(self, expr):
        """Raise NotImplementedError for expressions with no dedicated printer.

        The message names this printer, the missing ``_print_<Type>`` method,
        the expression, and the MRO of the expression's type.
        """
        # Assemble with f-strings only: the expression text is never passed
        # through str.format, so braces in it cannot break the message.
        mro = "\n *".join(t.__name__ for t in type(expr).__mro__)
        msg = (
            f"\n{self.__class__.__name__} does not implement "
            f"_print_{expr.__class__.__name__}(self, expr)."
            f"\nExpression is {expr}."
            "\nExpression type MRO is:"
            f"\n *{mro}"
        )
        raise NotImplementedError(msg)


class PrettyNamePrinter(BasePrinter, StrPrinter):
    """Printer producing the ``pretty_name`` of an expression.

    Objects carrying a ``pretty_name`` (or ``_pretty_name``) attribute print
    as that attribute; everything else falls back to sympy's ``StrPrinter``.
    """

    def _print(self, expr, **kwds):
        if hasattr(expr, "pretty_name"):
            return expr.pretty_name
        elif hasattr(expr, "_pretty_name"):
            return expr._pretty_name
        return super()._print(expr, **kwds)

    def _print_Derivative(self, expr):
        # Entry 1 of the name variants built by BasePrinter.print_Derivative.
        return super().print_Derivative(expr)[1]

    def emptyPrinter(self, expr):
        """Raise NotImplementedError for expressions with no dedicated printer.

        The message names this printer, the missing ``_print_<Type>`` method,
        the expression, and the MRO of the expression's type.
        """
        # Assemble with f-strings only: the expression text is never passed
        # through str.format, so braces in it cannot break the message.
        mro = "\n *".join(t.__name__ for t in type(expr).__mro__)
        msg = (
            f"\n{self.__class__.__name__} does not implement "
            f"_print_{expr.__class__.__name__}(self, expr)."
            f"\nExpression is {expr}."
            "\nExpression type MRO is:"
            f"\n *{mro}"
        )
        raise NotImplementedError(msg)


class VarNamePrinter(BasePrinter, C99CodePrinter):
    """Printer producing the ``var_name`` of an expression.

    Objects carrying a ``var_name`` (or ``_var_name``) attribute print as that
    attribute. Everything else falls back to sympy's ``C99CodePrinter``, with
    all spaces removed and ``+``/``-`` spelled out as ``plus_``/``minus_``.
    """

    def _print(self, expr, **kwds):
        if hasattr(expr, "var_name"):
            return expr.var_name
        elif hasattr(expr, "_var_name"):
            return expr._var_name
        return super()._print(expr, **kwds).replace(" ", "")

    def _print_Derivative(self, expr):
        # Entry 2 of the name variants built by BasePrinter.print_Derivative.
        return super().print_Derivative(expr)[2]

    def _print_Add(self, expr):
        s = super()._print_Add(expr)
        # Binary operators first, then any remaining (e.g. unary) signs.
        s = s.replace(" + ", "_plus_").replace(" - ", "_minus_")
        s = s.replace("+", "plus_").replace("-", "minus_")
        return s

    def _print_Mul(self, expr):
        s = super()._print_Mul(expr)
        # NOTE(review): sympy's code printers may join factors with "*" and no
        # surrounding spaces, in which case the " * " replacement never
        # matches and "*" survives in the name -- confirm.
        s = s.replace(" * ", "_times_").replace("+", "plus_").replace("-", "minus_")
        return s

    def emptyPrinter(self, expr):
        """Raise NotImplementedError for expressions with no dedicated printer.

        The message names this printer, the missing ``_print_<Type>`` method,
        the expression, and the MRO of the expression's type.
        """
        # Assemble with f-strings only: the expression text is never passed
        # through str.format, so braces in it cannot break the message.
        mro = "\n *".join(t.__name__ for t in type(expr).__mro__)
        msg = (
            f"\n{self.__class__.__name__} does not implement "
            f"_print_{expr.__class__.__name__}(self, expr)."
            f"\nExpression is {expr}."
            "\nExpression type MRO is:"
            f"\n *{mro}"
        )
        raise NotImplementedError(msg)


class LatexNamePrinter(BasePrinter, LatexPrinter):
    """Printer producing the ``latex_name`` of an expression.

    Objects carrying a ``latex_name`` (or ``_latex_name``) attribute print as
    that attribute; everything else falls back to sympy's ``LatexPrinter``.
    """

    def _print(self, expr, **kwds):
        if hasattr(expr, "latex_name"):
            return expr.latex_name
        elif hasattr(expr, "_latex_name"):
            return expr._latex_name
        return super()._print(expr, **kwds)

    def _print_Derivative(self, expr):
        # Entry 3 of the name variants built by BasePrinter.print_Derivative.
        return super().print_Derivative(expr)[3]

    def _print_int(self, expr):
        # Plain Python ints are rendered with str().
        return str(expr)

    def emptyPrinter(self, expr):
        """Raise NotImplementedError for expressions with no dedicated printer.

        The message names this printer, the missing ``_print_<Type>`` method,
        the expression, and the MRO of the expression's type.
        """
        # Assemble with f-strings only: the expression text is never passed
        # through str.format, so braces in it cannot break the message.
        mro = "\n *".join(t.__name__ for t in type(expr).__mro__)
        msg = (
            f"\n{self.__class__.__name__} does not implement "
            f"_print_{expr.__class__.__name__}(self, expr)."
            f"\nExpression is {expr}."
            "\nExpression type MRO is:"
            f"\n *{mro}"
        )
        raise NotImplementedError(msg)


# Shared module-level printer instances, used by print_name,
# print_pretty_name and print_latex_name.
pbn = NamePrinter()
ppn = PrettyNamePrinter()
# NOTE(review): no shared VarNamePrinter instance is kept; print_var_name
# builds a fresh VarNamePrinter on every call. The reason is not visible
# here (possibly per-call printer state) -- confirm before re-enabling.
# pvn = VarNamePrinter()
pln = LatexNamePrinter()
def print_name(expr):
    """Return the plain name of *expr*, rendered by the shared NamePrinter."""
    render = pbn.doprint
    return render(expr)


def print_pretty_name(expr):
    """Return the pretty name of *expr*, rendered by the shared PrettyNamePrinter."""
    render = ppn.doprint
    return render(expr)


def print_var_name(expr):
    """Return the variable name of *expr*.

    A new VarNamePrinter is constructed for every call.
    """
    printer = VarNamePrinter()
    return printer.doprint(expr)


def print_latex_name(expr):
    """Return the LaTeX name of *expr*, rendered by the shared LatexNamePrinter."""
    render = pln.doprint
    return render(expr)


def print_all_names(expr):
    """Return ``(name, pretty_name, var_name, latex_name)`` for *expr*."""
    printers = (print_name, print_pretty_name, print_var_name, print_latex_name)
    return tuple(printer(expr) for printer in printers)


def to_str(*args):
    """Return a tuple with ``str()`` applied to each given value.

    With exactly one positional argument, that argument is first expanded
    with ``to_tuple`` and each resulting item is converted; with zero or
    several arguments, the arguments themselves are converted.
    """
    values = to_tuple(args[0]) if len(args) == 1 else args
    return tuple(map(str, values))


# Exponent formatters, one per naming style: plain name (b), pretty name (p),
# variable name (v) and LaTeX name (l). Each yields "" unless x > 1.


def bexp_fn(x):
    """Plain-name exponent: ``^x`` when ``x > 1``, otherwise ``""``."""
    if x > 1:
        return f"^{x}"
    return ""


def pexp_fn(x, sep=","):
    """Pretty-name exponent: ``exponents(x, sep=sep)`` when ``x > 1``, otherwise ``""``."""
    if x > 1:
        return exponents(x, sep=sep)
    return ""


def vexp_fn(x):
    """Variable-name exponent: ``e<x>`` when ``x > 1``, otherwise ``""``."""
    if x > 1:
        return f"e{x}"
    return ""


def lexp_fn(x):
    """LaTeX exponent when ``x > 1``, otherwise ``""``."""
    if x > 1:
        return f"^<LBRACKET>{x}<RBRACKET>"
    return ""


# Power formatters, one per naming style: plain name (b), pretty name (p),
# variable name (v) and LaTeX name (l). Each yields "" unless x > 1.


def bpow_fn(x):
    """Plain-name power: ``**x`` when ``x > 1``, otherwise ``""``."""
    if x > 1:
        return f"**{x}"
    return ""


def ppow_fn(x, sep=","):
    """Pretty-name power: ``exponents(x, sep=sep)`` when ``x > 1``, otherwise ``""``."""
    if x > 1:
        return exponents(x, sep=sep)
    return ""


def vpow_fn(x):
    """Variable-name power: ``p<x>`` when ``x > 1``, otherwise ``""``."""
    if x > 1:
        return f"p{x}"
    return ""


def lpow_fn(x):
    """LaTeX power when ``x > 1``, otherwise ``""``."""
    if x > 1:
        return f"^<LBRACKET>{x}<RBRACKET>"
    return ""


# Subscript formatters, one per naming style: plain name (b), pretty name (p),
# variable name (v) and LaTeX name (l). Each yields "" when x is None.


def bsub_fn(x):
    """Plain-name subscript: ``_x``, or ``""`` when *x* is None."""
    if x is None:
        return ""
    return f"_{x}"


def psub_fn(x, sep=","):
    """Pretty-name subscript: ``subscripts(x, sep=sep)``, or ``""`` when *x* is None."""
    if x is None:
        return ""
    return subscripts(x, sep=sep)


def vsub_fn(x):
    """Variable-name subscript: ``s<x>``, or ``""`` when *x* is None."""
    if x is None:
        return ""
    return f"s{x}"


def lsub_fn(x):
    """LaTeX subscript, or ``""`` when *x* is None."""
    if x is None:
        return ""
    return f"_<LBRACKET>{x}<RBRACKET>"


# Component formatters, one per naming style: plain name (b), pretty name (p),
# variable name (v) and LaTeX name (l). Each yields "" when x is None.


def bcomp_fn(x):
    """Plain-name components joined with commas, or ``""`` when *x* is None."""
    if x is None:
        return ""
    return ",".join(to_str(x))


def pcomp_fn(x, sep=","):
    """Pretty-name components: ``subscripts(x, sep=sep)``, or ``""`` when *x* is None."""
    if x is None:
        return ""
    return subscripts(x, sep=sep)


def vcomp_fn(x):
    """Variable-name components, each prefixed by ``_``, or ``""`` when *x* is None."""
    if x is None:
        return ""
    return "_" + "_".join(to_str(x))


def lcomp_fn(x):
    """LaTeX components (comma-joined subscript), or ``""`` when *x* is None."""
    if x is None:
        return ""
    joined = ",".join(to_str(x))
    return "_<LBRACKET>{}<RBRACKET>".format(joined)


# Join formatters, one per naming style: plain name (b), pretty name (p),
# variable name (v) and LaTeX name (l). Each yields "" when x is None.


def bjoin_fn(x):
    """Join the items of *x* with ``_``, or return ``""`` when *x* is None."""
    if x is None:
        return ""
    return "_".join(to_str(x))


def pjoin_fn(x):
    """Concatenate the items of *x*, or return ``""`` when *x* is None."""
    if x is None:
        return ""
    return "".join(to_str(x))


def vjoin_fn(x):
    """Join the items of *x* with ``_``, or return ``""`` when *x* is None."""
    if x is None:
        return ""
    return "_".join(to_str(x))


def ljoin_fn(x):
    """Concatenate the items of *x*, or return ``""`` when *x* is None."""
    if x is None:
        return ""
    return "".join(to_str(x))


# Division formatters, one per naming style: plain name (b), pretty name (p),
# variable name (v) and LaTeX name (l).


def bdivide_fn(x, y):
    """Plain-name quotient ``x/y``."""
    return "/".join((f"{x}", f"{y}"))


def pdivide_fn(x, y):
    """Pretty-name quotient ``x/y`` of the stringified operands."""
    numerator, denominator = to_str(x, y)
    return f"{numerator}/{denominator}"


def vdivide_fn(x, y):
    """Variable-name quotient ``x__y``."""
    return "__".join((f"{x}", f"{y}"))


def ldivide_fn(x, y):
    """LaTeX ``\\dfrac`` quotient of *x* over *y*."""
    return rf"\dfrac<LBRACKET>{x}<RBRACKET><LBRACKET>{y}<RBRACKET>"


if __name__ == "__main__":
    # Manual smoke test: print sample outputs of DifferentialStringFormatter.

    def _print(*args, **kwds):
        """Print formatter results on one comma-separated line, or one per line.

        A single tuple argument is unpacked first; pass ``multiline=True`` to
        print one entry per line.
        """
        if isinstance(args[0], tuple):
            assert len(args) == 1
            args = args[0]
        if kwds.get("multiline") is True:
            for a in args:
                print(a)
        else:
            print(", ".join(args))

    bvar, pvar, vvar, lvar = (
        "Fext",
        "Fₑₓₜ",
        "Fext",
        "<LBRACKET>F_<LBRACKET>ext<RBRACKET><RBRACKET>",
    )
    _print(DifferentialStringFormatter.return_names(bvar, pvar, vvar, lvar))
    print()
    for kwds in (
        dict(dpow=0),
        dict(dpow=1),
        dict(dpow=2),
        dict(dpow=3, components=0),
        dict(dpow=4, components=(0, 2)),
    ):
        _print(
            DifferentialStringFormatter.format_partial_name(
                bvar, pvar, vvar, lvar, **kwds
            )
        )
    print()

    bvar, pvar, vvar, lvar = ("x",) * 4
    for kwds in (
        dict(varpow=1),
        dict(varpow=2),
        dict(varpow=3, components=0),
        dict(varpow=4, components=(0, 2)),
    ):
        _print(
            DifferentialStringFormatter.format_partial_name(
                bvar, pvar, vvar, lvar, **kwds
            )
        )
    print()

    bvar, pvar, vvar, lvar = (("x", "y"),) * 4
    # varpows=(0, 0) is expected to fail with an AssertionError.
    try:
        _print(
            DifferentialStringFormatter.format_partial_names(
                bvar, pvar, vvar, lvar, varpows=(0, 0)
            )
        )
    except AssertionError:
        pass
    else:
        raise RuntimeError("format_partial_names accepted varpows=(0, 0).")
    for kwds in (
        dict(varpows=(0, 1)),
        dict(varpows=(1, 0)),
        dict(varpows=(1, 1)),
        dict(varpows=(1, 2)),
        dict(varpows=(2, 2)),
        dict(varpows=(2, 2), components=(0, 1)),
        dict(varpows=(2, 2), components=((0, 1), (1, 0))),
    ):
        _print(
            DifferentialStringFormatter.format_partial_names(
                bvar, pvar, vvar, lvar, **kwds
            )
        )
    print()

    bvar, pvar, vvar, lvar = (
        "Fext",
        "Fₑₓₜ",
        "Fext",
        "<LBRACKET>F_<LBRACKET>ext<RBRACKET><RBRACKET>",
    )
    bxvars, pxvars, vxvars, lxvars = (("x", "y"),) * 4
    _print(DifferentialStringFormatter.format_pd(bvar, pvar, vvar, lvar))
    _print(DifferentialStringFormatter.format_pd(bvar, pvar, vvar, lvar, varpows=2))
    for pows in ((1, 0), (0, 1), (1, 1), (5, 2)):
        _print(
            DifferentialStringFormatter.format_pd(
                bvar, pvar, vvar, lvar, bxvars, pxvars, vxvars, lxvars, varpows=pows
            )
        )
    print()

    bxvars, pxvars, vxvars, lxvars = (("x",) * 5,) * 4
    varpows = (1,) * 5
    xvars_components = tuple(range(5))
    var_components = (0, 4, 3, 2)
    _print(
        DifferentialStringFormatter.format_pd(
            bvar,
            pvar,
            vvar,
            lvar,
            bxvars,
            pxvars,
            vxvars,
            lxvars,
            varpows=varpows,
            xvars_components=xvars_components,
            var_components=var_components,
        ),
        multiline=True,
    )